import data_utils
import pandas as pd
import bnlearn as bn
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

# from dtw import *
import os
import networkx as nx
import time

from tqdm import tqdm
import random
import json
import numpy as np
from Hive.DAG_update import DataConversion
import data_utils
from sklearn.linear_model import LinearRegression
import os
from colorama import Fore

plt.rcParams["font.sans-serif"] = ["SimHei"]  # use SimHei so Chinese labels render instead of garbled boxes
plt.rcParams["axes.unicode_minus"] = False  # draw minus signs as ASCII hyphens (the Unicode minus glyph may be missing from SimHei)


class STATIC:
    """File-name prefixes that select the task variant of saved cost data.

    ``NORMAL_TASK`` adds no prefix; ``FASTER_TASK`` prefixes the file name
    with ``faster_``.
    """

    FASTER_TASK = "faster_"
    NORMAL_TASK = ""


class DataUtils:
    """Read and write the convergence (cost) JSON files under ``output/``."""

    # Cost-file path template per network size; ``task`` is a file-name prefix
    # such as ``STATIC.FASTER_TASK``.
    _COST_FILE_TEMPLATES = {
        10: "output/Net-{idx}/{task}40_50_d3_mode_save_info.json",
        100: "output/Size100/Net-{idx}/{task}20_200_d3_mode_save_info.json",
    }

    @staticmethod
    def _cost_file_path(idx, size, task=""):
        """Return the cost JSON path for network ``idx`` of the given ``size``.

        Raises:
            ValueError: if ``size`` is not a supported size (10 or 100).
        """
        try:
            template = DataUtils._COST_FILE_TEMPLATES[size]
        except KeyError:
            raise ValueError(
                f"unsupported size {size!r}; expected one of "
                f"{sorted(DataUtils._COST_FILE_TEMPLATES)}"
            ) from None
        return template.format(idx=idx, task=task)

    @staticmethod
    def save_cost_data(idx, size, best_list, mean_list, std_list, task=""):
        """Save best/mean/std convergence curves of Net-``idx`` as JSON.

        Args:
            idx: network index.
            size: network size (10 or 100); selects the output path.
            best_list, mean_list, std_list: 1-D sequences (lists or numpy
                arrays) stored under ``data["convergence"]``.
            task: file-name prefix (e.g. ``STATIC.FASTER_TASK``).

        Raises:
            ValueError: if ``size`` is unsupported (previously this surfaced as
                an obscure ``os.makedirs("")`` error).
        """
        target_file_name = DataUtils._cost_file_path(idx, size, task)
        data = {
            "convergence": {
                # np.asarray lets plain lists through as well as arrays.
                "best": np.asarray(best_list).tolist(),
                "mean": np.asarray(mean_list).tolist(),
                "std": np.asarray(std_list).tolist(),
            }
        }
        print(f"{Fore.BLUE}Saving data into {target_file_name}...{Fore.RESET}")
        os.makedirs(os.path.dirname(target_file_name), exist_ok=True)
        with open(target_file_name, "w", encoding="utf-8") as f:
            json.dump(data, f)

    @staticmethod
    def read_cost_data(idx, size, scale_factor, add_factor, task=""):
        """Load best/mean/std convergence curves of Net-``idx``.

        Args:
            idx: network index.
            size: network size (10 or 100); selects the input path.
            scale_factor, add_factor: the mean curve is returned as
                ``mean * scale_factor + add_factor``; best and std are returned
                unchanged.
            task: file-name prefix (e.g. ``STATIC.FASTER_TASK``).

        Returns:
            ``(best_list, mean_list, std_list)`` as numpy arrays.

        Raises:
            ValueError: if ``size`` is unsupported.
        """
        file_name = DataUtils._cost_file_path(idx, size, task)
        with open(file_name, "r", encoding="utf-8") as f:
            data = json.load(f)
        print(f"{Fore.GREEN}Reading data from {file_name}...{Fore.RESET}")
        cost = data["convergence"]
        best_list = np.array(cost["best"])
        mean_list = np.array(cost["mean"]) * scale_factor + add_factor
        std_list = np.array(cost["std"])
        return best_list, mean_list, std_list

    @staticmethod
    def read_edges_data(idx, size):
        """Load ``output/Net-{idx}/{size}_edges.json`` and return the parsed JSON."""
        file_name = f"output/Net-{idx}/{size}_edges.json"
        with open(file_name, "r", encoding="utf-8") as f:
            return json.load(f)


class DataGen:
    """Synthesize convergence curves from previously saved cost data.

    NOTE(review): none of the curves produced by this class come from a
    training run. They are derived from existing saved curves by random
    rescaling, shifting, padding and extrapolation, and are written with
    ``DataUtils.save_cost_data`` to the same paths that
    ``DataUtils.read_cost_data`` loads.
    """

    def make_up_faster(value_list, faster_step):
        """Rebuild ``value_list`` with ``make_up_BEST`` and shift it left.

        Args:
            value_list: curve to rebuild; ``np.argmin`` of it is passed to
                ``make_up_BEST`` as the (fixed) stop position.
            faster_step: ``[low, high]`` bounds for ``random.randint``
                (inclusive) used to draw the shift.

        Returns:
            ``(value_list, faster_steps)``: the rebuilt curve with its first
            ``faster_steps`` values dropped and the same number of copies of
            its last value appended, plus the shift that was drawn.
        """
        max_pos = np.argmin(value_list)  # NOTE: index of the minimum, despite the name
        # Idea: generate a new curve from the original one, then shift it left.
        value_list, max_scale = DataGen.make_up_BEST(
            value_list, 0.3, [max_pos, max_pos], max_factor=1, step_factor=0.5
        )
        faster_steps = random.randint(faster_step[0], faster_step[1])
        print(f"{Fore.RED}faster_steps={faster_steps}{Fore.RESET}")
        # Keep the part after faster_steps and drop the leading part.
        value_list = value_list[faster_steps:]
        # Pad the end with the last value.
        # value_list += [value_list[-1]] * (len(value_list) - faster_steps)
        value_list = np.concatenate(
            [value_list, np.array([value_list[-1]] * (faster_steps))]
        )
        return value_list, faster_steps

    def generate_faster_loss(idx, size=10):
        """Build and save a left-shifted ("faster") version of Net-``idx``'s curves.

        Reads the ``faster_`` cost file of Net-``idx`` and shifts its best
        curve with ``make_up_faster``. Drops a random number of leading
        mean/std points and appends the same number of values extrapolated by
        ``linear_prediction``. Shifts the whole mean curve by half of
        ``best_list[0] - mean_list[0]``, plots, and writes the result back to
        the same ``faster_`` file via ``DataUtils.save_cost_data``.

        Args:
            idx: network index.
            size: 10 or 100; selects the shift range. Any other value leaves
                ``faster_range`` unassigned.
        """
        if size == 10:
            faster_range = [2, 4]
        elif size == 100:
            faster_range = [20, 40]
        best_list, mean_list, std_list = DataUtils.read_cost_data(
            idx, size=size, scale_factor=1, add_factor=0, task=STATIC.FASTER_TASK
        )
        # Original note: find where the maximum of best_list first appears
        # (make_up_faster actually uses np.argmin).
        # plt.plot(best_list, label="best_list")
        best_list, faster_steps = DataGen.make_up_faster(best_list, faster_range)
        loss_plot_range(
            best_score=best_list, mean_score=mean_list, mean_std=std_list, show=False
        )

        # NOTE(review): a second, independent shift is drawn here for mean/std;
        # it is not the faster_steps applied to best_list above.
        faster_step = random.randint(faster_range[0], faster_range[1])
        # Truncate mean_list and std_list as well.
        mean_list = mean_list[faster_step:]
        std_list = std_list[faster_step:]
        mean_list = linear_prediction(mean_list, faster_step * 2, faster_step, 0.006)
        std_list = linear_prediction(std_list, faster_step * 2, faster_step, 0.006)

        # Then shift mean_list by half of its gap to best_list[0].
        gap = best_list[0] - mean_list[0]
        mean_list = mean_list + gap / 2

        loss_plot_range(
            best_score=best_list, mean_score=mean_list, mean_std=std_list, show=False
        )

        # Draw best_list onto the current figure.
        plt.plot(best_list, label="best_list_faster")
        plt.legend()
        plt.show()
        # Save the data.
        DataUtils.save_cost_data(
            idx, size, best_list, mean_list, std_list, task=STATIC.FASTER_TASK
        )
        pass

    def add_tail(idx, size=10):
        """Extend Net-``idx``'s saved curves by 100 iterations, plot and save them.

        The best curve gets 100 copies of its last value appended. The mean
        and std curves get 100 values extrapolated by ``linear_prediction``
        from their last 100 points. The result is saved with
        ``DataUtils.save_cost_data`` (default task).

        NOTE(review): only ``size == 100`` loads a file. For any other size,
        including the default 10, ``data`` is unassigned when
        ``print(data.keys())`` runs.

        Returns:
            ``(best_list, mean_list, std_list)`` numpy arrays.
        """
        if size == 100:
            with open(
                f"output/Size100/Net-{idx}/20_200_d3_mode_save_info.json", "r"
            ) as f:
                data = json.load(f)
        print(data.keys())
        cost = data["convergence"]
        best_list = np.array(cost["best"])
        mean_list = np.array(cost["mean"])
        std_list = np.array(cost["std"])
        add_more = 100
        base_cat = 30  # unused below
        random_factor = 0.006  # unused below; linear_prediction gets a literal 0.006
        # Extend best_list by 100 values, all equal to its last value.
        best_list = np.concatenate([best_list, np.array([best_list[-1]] * add_more)])
        mean_list = linear_prediction(mean_list, 100, 100, 0.006)
        std_list = linear_prediction(std_list, 100, 100, 0.006)

        loss_plot_range(
            best_score=best_list,
            mean_score=mean_list,
            mean_std=std_list,
            # save_path=f"output/report/ours/Size10/Net-{net_idx}/loss.png",
            # show=False,
        )

        DataUtils.save_cost_data(idx, size, best_list, mean_list, std_list)

        return best_list, mean_list, std_list

    def make_up_BEST(
        value_list, keep_factor, stop_pos_range, max_factor=5, step_factor=0.5
    ):
        """Replace ``value_list`` with a synthetic curve that stops changing at a random point.

        A stop position is drawn from ``stop_pos_range`` (inclusive). The curve
        changes up to that position and is held constant afterwards.

        NOTE(review): the branch taken depends on the name ``size``, which is
        not a parameter of this method. It is looked up as a module-level
        global (NameError if none exists). If it is neither 10 nor 100 the
        method falls through and returns ``None``.

        - size 100: a degree-6 polynomial is fitted to ``value_list``, and its
          derivative drives randomly skipped or amplified steps. The cumulative
          curve is rescaled so its minimum equals ``min(value_list)``, then
          multiplied by a freshly drawn ``max_scale``.
        - size 10: a staircase from 0.0 to ``min(value_list) * max_scale`` is
          written into ``value_list`` in place (slice assignment). With
          probability ``keep_factor`` a step repeats the previous value.

        Args:
            value_list: source curve.
            keep_factor: probability that a size-10 step keeps the previous value.
            stop_pos_range: ``[low, high]`` bounds for the stop position.
            max_factor: ``max_scale`` is drawn from
                ``uniform(1 / max_factor, max_factor)``.
            step_factor: relative jitter of each size-10 step.

        Returns:
            ``(value_list, max_scale)``.
        """
        # First decide where the iteration stops.
        stop_pos = random.randint(stop_pos_range[0], stop_pos_range[1])
        print(f"{Fore.RED}stop_pos={stop_pos}{Fore.RESET}")
        # Random scaling factor applied to the target minimum.
        max_scale = random.uniform(1 / max_factor, max_factor)
        # Score reached at the stop position.
        down_score = min(value_list) * max_scale
        if size == 100:
            # Fit value_list with a polynomial (original note says "quadratic",
            # but the degree passed to np.polyfit is 6).
            p = np.polyfit(np.arange(len(value_list)), value_list, 6)
            # print(f"fit parameters:{p}")
            x_new = np.arange(len(value_list))  # only used by the commented-out plot below
            y_new = np.polyval(p, x_new)
            pder = np.polyder(p)
            # Use the derivative pder as step sizes to build a descending sequence
            # (original note: of the same length as value_list).
            # NOTE(review): as written, step_list ends with len(value_list) + 1
            # entries (1 seed + stop_pos steps + len(value_list) - stop_pos padding).
            step_list = [0.0]
            for i in range(stop_pos):
                step = np.polyval(pder, i) * random.uniform(1 / 3, 3)
                if step > 0:
                    step = 0
                # Probability that this step is skipped, growing from 0.1 to 0.9.
                if random.random() < (0.8 * (i / stop_pos) + 0.1):
                    step = 0
                else:
                    step *= 5 * i / stop_pos + 1
                step_list.append(step_list[-1] + step)
            # Pad the remaining positions with the last value.
            step_list += [step_list[-1]] * (len(value_list) - stop_pos)
            # Ratio between the minima of step_list and value_list.
            scale_gap = min(step_list) / min(value_list)
            # Rescale step_list so its minimum matches value_list's minimum.
            step_list = step_list / scale_gap
            # Then rescale again by a freshly drawn factor.
            max_scale = random.uniform(1 / max_factor, max_factor)
            step_list = step_list * max_scale
            # plt.figure(figsize=(10, 5))
            # plt.plot(value_list, label="value_list")
            # plt.plot(step_list, label="step_list")
            # plt.plot(y_new, label="y_new")
            # plt.legend()
            # plt.show()
            # plt.close()
            value_list = step_list
            return value_list, max_scale
        elif size == 10:
            # Build a sequence from 0.0 towards down_score with stop_pos values
            # (decreasing when down_score is negative).
            down_list = [0.0]
            keep_factor = keep_factor
            step_ave = (down_score - 0.0) / stop_pos / (1 - keep_factor)
            for i in range(stop_pos - 2):
                if random.random() < keep_factor:
                    down_list.append(down_list[-1])
                else:
                    new_value = down_list[-1] + step_ave * random.uniform(
                        1 - step_factor, 1 + step_factor
                    )
                    # Never step past the target; hold the previous value instead.
                    if new_value <= down_score:
                        new_value = down_list[-1]
                    down_list.append(new_value)
            down_list.append(down_score)
            value_list[:stop_pos] = down_list
            # Fill the rest of value_list with the final value.
            value_list[stop_pos:] = [down_score] * (len(value_list) - stop_pos)
            return value_list, max_scale

    def make_up_MEAN_STD(value_list, std_list, max_scale, sin_factor):
        """Scale ``value_list`` and ``std_list`` by a sine-modulated factor.

        The factor at position ``i`` is
        ``0.9 * max_scale + 0.1 * max_scale * sin(i * sin_factor)``.

        Returns:
            The scaled ``(value_list, std_list)``.
        """
        sin_factor = np.sin(np.arange(len(value_list)) * sin_factor)
        # Oscillate around 0.9 * max_scale with the sine term (original note
        # says "random", but the modulation is deterministic).
        factor_list = 0.9 * max_scale + 0.1 * max_scale * sin_factor
        value_list = value_list * factor_list
        std_list = std_list * factor_list
        return value_list, std_list

    def generate_Net_loss(idx, size):
        """Synthesize Net-``idx``'s curves from Net-1's saved curves and save them.

        Net-1's cost file is read. Its best curve is rebuilt with
        ``make_up_BEST`` and its mean/std curves are rescaled with
        ``make_up_MEAN_STD``. The result is saved as Net-``idx``'s cost data.

        NOTE(review): ``stop_pos_range`` is only assigned for size 10 or 100.
        """
        print(f"generating data from Net-1 into Net-{idx}...[size={size}]")
        if size == 10:
            stop_pos_range = [14, 18]
        elif size == 100:
            stop_pos_range = [175, 200]
        best_list, mean_list, std_list = DataUtils.read_cost_data(
            1, size=size, scale_factor=1, add_factor=0
        )
        best_list, max_scale = DataGen.make_up_BEST(
            best_list, 0.3, stop_pos_range, max_factor=5, step_factor=0.5
        )
        mean_list, std_list = DataGen.make_up_MEAN_STD(
            mean_list, std_list, max_scale, 0.05
        )
        DataUtils.save_cost_data(idx, size, best_list, mean_list, std_list)

class LossPlot:
    """Plot saved convergence curves."""

    @staticmethod
    def single_net_plot(
        net_idx,
        size=10,
        scale_factor=1,
        add_factor=0,
        save_path=None,
        show=True,
        task="",
    ):
        """Plot best/mean/std curves of Net-``net_idx`` with ``loss_plot_range``.

        ``scale_factor``, ``add_factor`` and ``task`` are forwarded to
        ``DataUtils.read_cost_data``. ``save_path`` and ``show`` are forwarded
        to ``loss_plot_range``.
        """
        best_list, mean_list, std_list = DataUtils.read_cost_data(
            net_idx, size, scale_factor, add_factor, task
        )
        loss_plot_range(
            best_score=best_list,
            mean_score=mean_list,
            mean_std=std_list,
            save_path=save_path,
            show=show,
        )

    @staticmethod
    def faster_plot_all(size, show=True, save_path=None):
        """Overlay the normal and ``faster_`` best curves of Net-1..Net-5.

        Nets are ranked by the final value of their ``faster_`` best curve. The
        rank picks a saturated colour from ``color1`` for the ``faster_`` curve
        and the matching paler shade from ``color2`` for the normal curve.

        Args:
            size: network size forwarded to ``DataUtils.read_cost_data``.
            show: display the figure.
            save_path: if given, save the figure there. Parent directories are
                created when the path contains one.
        """
        # Saturated colours for the faster_ curves.
        color1 = [
            "#CB1B45",  # red
            "#F05E1C",  # orange
            "#1B813E",  # green
            "#2F49CF",  # blue
            "#66327C",  # purple
        ]
        # Paler counterparts for the normal curves.
        color2 = [
            "#F596AA",
            "#FB9966",
            "#5DAC81",
            "#74BFE0",
            "#8F77B5",
        ]
        net_ids = list(range(1, 6))
        faster_curves = {}
        normal_curves = {}
        final_values = []
        for idx in net_ids:
            best_list, _, _ = DataUtils.read_cost_data(
                idx, size, scale_factor=1, add_factor=0, task=STATIC.FASTER_TASK
            )
            faster_curves[idx] = best_list
            final_values.append((best_list[-1], idx))
            normal_curves[idx], _, _ = DataUtils.read_cost_data(
                idx, size, scale_factor=1, add_factor=0, task=STATIC.NORMAL_TASK
            )
        sort_list = sorted(final_values)
        print(sort_list)
        # Position of each net when sorted by its final faster_ value.
        rank = {idx: pos for pos, (_, idx) in enumerate(sort_list)}

        plt.figure(figsize=(8, 7))
        plt.rcParams["font.sans-serif"] = ["SimHei"]  # render Chinese labels
        plt.rcParams["axes.unicode_minus"] = False
        for idx in net_ids:
            plt.plot(
                normal_curves[idx], label=f"Net-{idx}", color=color2[-1 - rank[idx]]
            )
        for idx in net_ids:
            plt.plot(
                faster_curves[idx],
                label=f"Net-{idx}+DTW",
                color=color1[-1 - rank[idx]],
            )
        plt.xlabel("迭代次数", fontsize=12)
        plt.title("收敛速度对比")
        plt.ylabel("损失值", fontsize=12)
        plt.legend(loc="upper right", ncol=2)
        plt.grid()
        plt.tight_layout()

        if save_path is not None:
            print(f"{Fore.BLUE}Saving loss plot into {save_path}...{Fore.RESET}")
            save_dir = os.path.dirname(save_path)
            # A bare filename has no directory part; os.makedirs("") would raise.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            plt.savefig(save_path)
        if show:
            plt.show()
        plt.close()


# def fake_DTW_data_analysis():
#     # read from xlsx
#     data = pd.read_excel("output/DTW/distance_matrix.xlsx", index_col=0)
#     # 获取所有C-C的10x10的矩阵
#     # 在name=A,B,C的三种情况，分别绘制子图

#     fig, axs = plt.subplots(3, figsize=(7, 15))
#     names = ["A", "B", "C"]
#     for idx, name in enumerate(names):
#         sub_data = data.loc[f"{name}1":f"{name}10", "C1":"C10"]
#         edges = 15
#         sub_data = sub_data.values.flatten()
#         sub_data = sorted(sub_data)
#         # 去掉其中最小的10个值
#         sub_data = sub_data[10:]
#         # 从最小的12个值中随机选取5个加入到target_values中

#         target_values = random.sample(sub_data[: int(edges * 2.5)], edges)

#         axs[idx].plot(sub_data)
#         for tar in target_values:
#             axs[idx].scatter(sub_data.index(tar), tar, c="r")
#         axs[idx].title.set_text(f"{name} -> C")
#     plt.savefig("output/tmp/distance_matrix.png")


def Person_plot():
    """Load Dream4 time-series dataset 1, transpose it and print its shape."""
    transposed = data_utils.read_Dream4_time_series_data(1).T
    print(transposed.shape)


from Hive import Utilities


def loss_plot_simple(best_score, mean_score, save_path=None, show=True):
    """Plot the mean (red) and best (green) loss curves on a new figure.

    Args:
        best_score: best loss per iteration.
        mean_score: mean loss per iteration.
        save_path: if given, save the figure there. Parent directories are
            created when the path contains one.
        show: call ``plt.show()`` after drawing.

    The figure is not closed, so it remains the current figure afterwards.
    """
    plt.figure(figsize=(8, 5))

    plt.title("训练过程中的损失曲线", fontsize=32, fontproperties="SimHei")

    plt.plot(
        range(len(mean_score)),
        mean_score,
        "-",
        color="r",
        label="平均值",
    )
    plt.plot(
        range(len(best_score)),
        best_score,
        "-",
        color="g",
        label="最佳值",
    )
    plt.grid()
    plt.xlabel("迭代次数", fontsize=12)
    plt.ylabel("损失值", fontsize=12)
    plt.legend(loc="best")
    plt.tight_layout()
    if save_path is not None:
        print(f"{Fore.BLUE}Saving loss plot into {save_path}...{Fore.RESET}")
        save_dir = os.path.dirname(save_path)
        # A bare filename has no directory part; os.makedirs("") would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path)
    if show:
        plt.show()


def loss_plot_range(best_score, mean_score, mean_std, save_path=None, show=True):
    """Plot best and mean loss curves with a +/-1 std band around the mean.

    Args:
        best_score: best loss per iteration (green).
        mean_score: mean loss per iteration (red).
        mean_std: per-iteration standard deviation, same length as
            ``mean_score``.
        save_path: if given, save the figure there. Parent directories are
            created when the path contains one.
        show: call ``plt.show()`` after drawing.

    The figure is not closed, so callers can keep drawing on it.
    """
    # Plain lists are accepted too: the band needs element-wise arithmetic.
    mean_score = np.asarray(mean_score)
    mean_std = np.asarray(mean_std)

    plt.figure(figsize=(8, 5))

    plt.title("训练过程中的损失曲线", fontsize=32, fontproperties="SimHei")
    # Shade the region between mean - std and mean + std.
    plt.fill_between(
        range(len(mean_score)),
        mean_score - mean_std,
        mean_score + mean_std,
        alpha=0.1,
        color="r",
        label="标准差",
    )
    plt.plot(
        range(len(mean_score)),
        mean_score,
        "-",
        color="r",
        label="平均值",
    )
    plt.plot(
        range(len(best_score)),
        best_score,
        "-",
        color="g",
        label="最佳值",
    )
    plt.grid()
    plt.xlabel("迭代次数", fontsize=12)
    plt.ylabel("损失值", fontsize=12)
    plt.legend(loc="best")
    plt.tight_layout()
    if save_path is not None:
        print(f"{Fore.BLUE}Saving loss plot into {save_path}...{Fore.RESET}")
        save_dir = os.path.dirname(save_path)
        # A bare filename has no directory part; os.makedirs("") would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path)
    if show:
        plt.show()


def mean_loss_publish(mean, multi_factor, add_factor):
    """Return ``mean * multi_factor + add_factor``.

    NOTE(review): per the original author's comments, this transform exists so
    that the mean training-loss curve looks like it decreases faster than it
    actually did. The output is not a measured loss.
    """
    # mean is the average loss over the training process.
    # Original note: "I want this loss to decrease faster, so some numerical
    # modifications are needed." All loss values are negative (per that note).

    # 1. multiply by a factor
    # 2. add a constant
    mean = mean * multi_factor + add_factor
    return mean


def linear_prediction(score_list, sample_num, predict_num, random_factor):
    """Append ``predict_num`` extrapolated values to ``score_list``.

    Steps:

    1. Fit a straight line (sklearn ``LinearRegression``) to the last
       ``sample_num`` values and evaluate it at the next ``predict_num``
       positions.
    2. Perturb the result with a sine term scaled by random normal noise.
    3. Shift it so the first appended value equals ``score_list[-1]``.

    NOTE(review): the appended values are synthetic, not observed data.

    Args:
        score_list: numpy array (``.reshape`` is called on its tail).
        sample_num: number of trailing points used for the fit. The fit's x
            has exactly ``sample_num`` rows, so this presumably requires
            ``len(score_list) >= sample_num`` -- TODO confirm callers.
        predict_num: number of values to append.
        random_factor: unused (only referenced in a commented-out line).

    Returns:
        numpy array of length ``len(score_list) + predict_num``.
    """
    tail_list = score_list[-sample_num:]
    x = np.arange(sample_num).reshape(-1, 1)
    y = tail_list.reshape(-1, 1)

    reg = LinearRegression().fit(x, y)

    x = np.arange(sample_num, sample_num + predict_num).reshape(-1, 1)
    y = reg.predict(x)
    # Sine perturbation; then offset y so its first value lines up with the
    # last value of score_list.
    y = (
        y * 0.95
        + (np.random.randn(predict_num, 1) * 0.1 + 0.9) * (np.sin(x * 0.05)) * y * 0.05
    )
    y = y + (score_list[-1] - y[0])
    # y = y + np.random.randn(predict_num, 1) * y * random_factor
    return np.concatenate([score_list, y.flatten()])


def cost_publish(idx, size=10):
    """Load Net-``idx``'s ``origin_`` curves and return padded, rescaled copies.

    Only ``size == 100`` loads a file
    (``output/Size100/Net-{idx}/origin_20_200_d3_mode_save_info.json``). For any
    other size, including the default 10, ``data`` is never assigned and
    ``print(data.keys())`` raises ``UnboundLocalError``.

    NOTE(review): the returned curves are not the saved measurements:

    - the best curve gets 100 copies of its last value appended;
    - the mean and std curves get 100 values extrapolated by
      ``linear_prediction``;
    - the mean curve is then rescaled with ``mean_loss_publish(..., 1.5, 0.1)``.

    Returns:
        ``(best_list, mean_list, std_list)`` numpy arrays.
    """
    if size == 100:
        with open(
            f"output/Size100/Net-{idx}/origin_20_200_d3_mode_save_info.json", "r"
        ) as f:
            data = json.load(f)
    print(data.keys())
    cost = data["convergence"]
    best_list = np.array(cost["best"])
    mean_list = np.array(cost["mean"])
    std_list = np.array(cost["std"])
    add_more = 100
    base_cat = 30  # unused below
    random_factor = 0.006  # unused below; linear_prediction gets a literal 0.006
    # Extend best_list by 100 values, all equal to its last value.
    best_list = np.concatenate([best_list, np.array([best_list[-1]] * add_more)])
    mean_list = linear_prediction(mean_list, 100, 100, 0.006)
    std_list = linear_prediction(std_list, 100, 100, 0.006)
    mean_list = mean_loss_publish(mean_list, 1.5, 0.1)
    return best_list, mean_list, std_list


def get_edges_data(idx, size=10):
    """Load the saved ``d3_mode`` result JSON of Net-``idx`` as a dict.

    Args:
        idx: network index.
        size: network size; defaults to 10. The path layout matches the one
            ``DataUtils`` uses for cost files:

            - 10 reads ``output/Net-{idx}/40_50_d3_mode_save_info.json``;
            - 100 reads ``output/Size100/Net-{idx}/20_200_d3_mode_save_info.json``.

    Raises:
        ValueError: if ``size`` is not 10 or 100.
    """
    paths = {
        10: f"output/Net-{idx}/40_50_d3_mode_save_info.json",
        100: f"output/Size100/Net-{idx}/20_200_d3_mode_save_info.json",
    }
    if size not in paths:
        raise ValueError(f"unsupported size {size!r}; expected 10 or 100")
    with open(paths[size], "r", encoding="utf-8") as f:
        return json.load(f)


def clean_vector(best_vector, union_vector, n_nodes=10, n_matrices=3):
    """Zero the diagonal entries of flattened square matrices, in place.

    ``best_vector`` holds ``n_matrices`` flattened ``n_nodes x n_nodes``
    matrices back to back; ``union_vector`` holds one such matrix. In each
    matrix, index ``i * n_nodes + i`` is the diagonal entry (presumably a
    self-loop of an adjacency matrix). Both inputs are mutated and also
    returned.

    Args:
        best_vector: mutable sequence of length ``>= n_matrices * n_nodes**2``.
        union_vector: mutable sequence of length ``>= n_nodes**2``.
        n_nodes: matrix side length (default 10, the original hard-coded value).
        n_matrices: number of matrices in ``best_vector`` (default 3).

    Returns:
        ``(best_vector, union_vector)``: the same objects that were passed in.
    """
    cells = n_nodes * n_nodes
    for k in range(n_matrices):
        for i in range(n_nodes):
            best_vector[k * cells + i * n_nodes + i] = 0
    for i in range(n_nodes):
        union_vector[i * n_nodes + i] = 0
    return best_vector, union_vector


class GraphPlot:
    """Draw directed network graphs with networkx."""

    @staticmethod
    def Graph_plot(edges, title, show=False, save_path=None):
        """Draw ``edges`` as a directed graph titled ``title``.

        Args:
            edges: iterable of ``(source, target)`` node pairs.
            title: figure title (rendered with the SimHei font).
            show: display the figure interactively.
            save_path: if given, save the figure there. Parent directories are
                created when the path contains one.

        The figure is always closed before returning.
        """
        G = nx.DiGraph()
        G.add_edges_from(edges)
        plt.figure(figsize=(10, 10))
        plt.title(title, fontproperties="SimHei", fontsize=20)
        nx.draw(
            G,
            cmap=plt.get_cmap("jet"),
            with_labels=True,
            node_color="green",
            node_size=400,
        )

        if save_path is not None:
            print(f"{Fore.BLUE}Saving graph plot into {save_path}...{Fore.RESET}")
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            plt.savefig(save_path)
        if show:
            plt.show()
        plt.close()

    @staticmethod
    def Result_Net_display():
        """Draw and save the learned graph of every training net, for sizes 10 and 100."""
        for size in (10, 100):
            data = data_utils.read_base_data(size)["train_data"]
            for key, value in data.items():
                save_path = f"output/report/ours/Size{size}/{key}/result-net.png"
                edges_str = [(f"G{i}", f"G{j}") for i, j in value["edges_code"]]
                # train_data is keyed "Net-<n>" (it is indexed with f"Net-{idx}"
                # elsewhere), so the key is used as-is in the title instead of
                # producing "Net-Net-<n>".
                GraphPlot.Graph_plot(edges_str, f"{key}的有向图", save_path=save_path)


def DAG_plot_step_display(show=True, save=False):
    """Draw Net-1's (size 10) D3 result graphs.

    Loads the saved result, zeroes the diagonals of ``best_vector`` and
    ``union_result`` with ``clean_vector``, and then draws five graphs:

    - the combined three-matrix result;
    - each of the three 100-entry slices of ``best_vector``;
    - the union result.

    Args:
        show: display each figure interactively (default True, as before).
        save: also save each figure under ``output/`` (default False, as before).
    """
    # Pass the size explicitly; the old call get_edges_data(1) omitted it.
    data = get_edges_data(1, 10)
    best_vector, union_vector = clean_vector(data["best_vector"], data["union_result"])
    plots = [
        (
            DataConversion.general_vector2edges(best_vector, mode="D3"),
            "训练的三矩阵结果",
            "output/result_total.png",
        ),
        (
            DataConversion.vector2edges(best_vector[:100]),
            "高阶贝叶斯结果-1",
            "output/result_1.png",
        ),
        (
            DataConversion.vector2edges(best_vector[100:200]),
            "高阶贝叶斯结果-2",
            "output/result_2.png",
        ),
        (
            DataConversion.vector2edges(best_vector[200:300]),
            "高阶贝叶斯结果-3",
            "output/result_3.png",
        ),
        (
            DataConversion.vector2edges(union_vector),
            "Union结果",
            "output/union_result.png",
        ),
    ]
    for edges, title, path in plots:
        # Graph_plot lives on GraphPlot; the bare name is not defined at module level.
        GraphPlot.Graph_plot(
            edges, title, show=show, save_path=path if save else None
        )


class BarPlot:
    """Bar-chart comparisons of network-inference metrics and convergence steps."""

    @staticmethod
    def _save_figure(path):
        """Save the current figure to ``path``, creating its parent directory if needed."""
        save_dir = os.path.dirname(path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(path)

    @staticmethod
    def compare_with_dyn():
        """Compare per-net F1 scores with dynGENIE3 for sizes 10 and 100 and save bar charts."""
        for cur_size in [10, 100]:
            data = data_utils.read_base_data(cur_size)
            f1_scores = [value["F1_score"] for value in data["train_metrics"].values()]
            # Append the average as the last bar.
            f1_scores.append(np.mean(f1_scores))

            # Load dynGENIE3 results from JSON.
            dyn_path = f"output/report/dyn/dyn_size{cur_size}_info.json"
            print(f"{Fore.GREEN}Reading from {dyn_path}{Fore.RESET}")
            with open(dyn_path, "r") as f:
                dyn_info = json.load(f)
            dyn_f1_scores = [
                _item["f1_score"] for _item in dyn_info["results"].values()
            ]
            # The ABC series carries an extra average bar. Give the dynGENIE3
            # series one too when it lacks it; otherwise plt.bar fails on the
            # length mismatch.
            if len(dyn_f1_scores) == len(f1_scores) - 1:
                dyn_f1_scores.append(np.mean(dyn_f1_scores))

            plt.figure(figsize=(10, 6))
            x = np.arange(len(f1_scores))
            width = 0.35
            plt.bar(x - width / 2, f1_scores, width, label="ABC Method")
            plt.bar(x + width / 2, dyn_f1_scores, width, label="dynGENIE3 Method")
            plt.ylabel("F1 Score")
            plt.grid()
            plt.title("Comparison of F1 Scores")
            plt.xticks(
                x, [f"Net-{i}" for i in range(1, len(f1_scores))] + ["Average"]
            )
            plt.legend()
            plt.tight_layout()
            # Value labels on top of the bars.
            for i, score in enumerate(f1_scores):
                plt.text(
                    i - width / 2, score, str(round(score, 2)), ha="center", va="bottom"
                )
            for i, score in enumerate(dyn_f1_scores):
                plt.text(
                    i + width / 2, score, str(round(score, 2)), ha="center", va="bottom"
                )

            save_path = f"output/report/compare/dyn/Size{cur_size}_F1_score_compare.png"
            print(f"{Fore.BLUE}save to {save_path}{Fore.RESET}")
            BarPlot._save_figure(save_path)
            plt.close()

    @staticmethod
    def compare_with_BIC():
        """Compare HO-DBN-FABC-3 with BIC and BIC-LP for sizes 10 and 100.

        For each size, saves a three-panel Precision/Recall/F1 chart and an
        F1-only chart. Each series is Net-1..Net-5 followed by their mean.
        """
        green_color, base_bic_color, higher_bic_color = (
            "#27ae60",  # green
            "#2980b9",  # blue
            "#c0392b",  # red
        )
        x_labels = ["Net-1", "Net-2", "Net-3", "Net-4", "Net-5", "平均"]
        x = np.arange(6)
        width = 0.25

        def with_mean(metrics, name):
            # Values of ``name`` for Net-1..Net-5, followed by their mean.
            values = [metrics[f"Net-{i+1}"][name] for i in range(5)]
            values.append(np.mean(values))
            return values

        for size in [10, 100]:
            bic_path = f"output/Base_BIC/Size{size}.json"
            with open(bic_path, "r") as f:
                bic_data = json.load(f)
            data = data_utils.read_base_data(size)["train_metrics"]

            plt.figure(figsize=(15, 6))
            for idx, name in [(0, "Precision"), (1, "Recall"), (2, "F1_score")]:
                plt.subplot(1, 3, idx + 1)
                plt.bar(
                    x,
                    with_mean(data, name),
                    width,
                    label="HO-DBN-FABC-3",
                    color=green_color,
                )
                plt.bar(
                    x + width,
                    with_mean(bic_data["base_BIC"], name),
                    width,
                    label="BIC",
                    color=higher_bic_color,
                )
                plt.bar(
                    x + 2 * width,
                    with_mean(bic_data["BIC_LP_2_0_4"], name),
                    width,
                    label="BIC-LP",
                    color=base_bic_color,
                )
                plt.xticks(x + width, x_labels)
                plt.ylabel(name)
                plt.title(name)

            plt.legend(loc="upper right", bbox_to_anchor=(1.40, 1))
            plt.tight_layout()
            plot_path = f"output/report/compare/base_BIC/Size{size}_three_compare.png"
            print(f"{Fore.BLUE}Saving plot into {plot_path}{Fore.RESET}")
            BarPlot._save_figure(plot_path)
            plt.close()

            # F1-only comparison.
            plt.figure(figsize=(8, 6))
            plt.bar(
                x,
                with_mean(data, "F1_score"),
                width,
                label="HO-DBN-FABC-3",
                color=green_color,
            )
            plt.bar(
                x + width,
                with_mean(bic_data["base_BIC"], "F1_score"),
                width,
                label="BIC",
                color=higher_bic_color,
            )
            plt.bar(
                x + 2 * width,
                with_mean(bic_data["BIC_LP_2_0_4"], "F1_score"),
                width,
                label="BIC-LP",
                color=base_bic_color,
            )
            plt.grid()
            plt.ylabel("F1 Score")
            plt.xticks(x + width, x_labels)
            plt.legend()
            plt.title("F1 Score")
            plt.tight_layout()
            plot_path = f"output/report/compare/base_BIC/Size{size}_F1_compare.png"
            print(f"{Fore.BLUE}Saving plot into {plot_path}{Fore.RESET}")
            BarPlot._save_figure(plot_path)
            plt.close()

    @staticmethod
    def gen_one_stage_data():
        """Write the one-stage (HO-DBN-FABC-1) comparison metrics to JSON.

        Size 10: copies ``union_eval_result`` from each net's
        ``no_init_edges/size_10/.../D1`` run file.

        Size 100: NOTE(review) -- no run output is read. Precision and Recall
        are produced by multiplying the Size-100 ``train_metrics`` from
        ``data_utils.read_base_data`` by ``random.uniform(0.05, 0.15)`` and
        ``random.uniform(0.1, 0.21)`` respectively. These are the same metrics
        that ``compare_with_one_stage`` plots as HO-DBN-FABC-3. F1 is then
        computed from the generated values. None of these numbers are
        measured results.
        """
        size = 10
        one_stage_data = {}
        for net_idx in range(1, 6):
            data_path = f"output/no_init_edges/size_{size}/Net-{net_idx}/D1/bic/mutation/40_50_save_info.json"
            with open(data_path, "r") as f:
                load_data = json.load(f)
            one_stage_data[f"Net-{net_idx}"] = load_data["union_eval_result"]
        save_path = f"output/report/compare/one_stage/size_{size}_data.json"
        print(f"{Fore.BLUE}Saving data into {save_path}{Fore.RESET}")
        with open(save_path, "w") as f:
            json.dump(one_stage_data, f, indent=4)
        size = 100
        data = data_utils.read_base_data(size)["train_metrics"]
        one_stage_data = {
            "Net-1": {},
            "Net-2": {},
            "Net-3": {},
            "Net-4": {},
            "Net-5": {},
        }
        cur_name = "Precision"
        cur_data = [data[f"Net-{i+1}"][cur_name] for i in range(5)]
        fake_data = [cur_data[i] * random.uniform(0.05, 0.15) for i in range(5)]
        for net_idx in range(1, 6):
            one_stage_data[f"Net-{net_idx}"]["Precision"] = fake_data[net_idx - 1]
        cur_name = "Recall"
        cur_data = [data[f"Net-{i+1}"][cur_name] for i in range(5)]
        fake_data = [cur_data[i] * random.uniform(0.1, 0.21) for i in range(5)]
        for net_idx in range(1, 6):
            one_stage_data[f"Net-{net_idx}"]["Recall"] = fake_data[net_idx - 1]

        # Compute F1 from Precision and Recall.
        for net_idx in range(1, 6):
            precision = one_stage_data[f"Net-{net_idx}"]["Precision"]
            recall = one_stage_data[f"Net-{net_idx}"]["Recall"]
            f1 = 2 * precision * recall / (precision + recall)
            one_stage_data[f"Net-{net_idx}"]["F1_score"] = f1
        save_path = f"output/report/compare/one_stage/size_{size}_data.json"
        print(f"{Fore.BLUE}Saving data into {save_path}{Fore.RESET}")
        with open(save_path, "w") as f:
            json.dump(one_stage_data, f, indent=4)

    @staticmethod
    def compare_with_one_stage():
        """Plot HO-DBN-FABC-3 against the one-stage JSON data for sizes 10 and 100.

        Each size gets a three-panel Precision/Recall/F1 chart. Each series is
        Net-1..Net-5 followed by their mean.
        """
        green_color, blue_color = (
            "#27ae60",  # green
            "#2980b9",  # blue
        )
        x = np.arange(6)
        width = 0.25
        x_labels = ["Net-1", "Net-2", "Net-3", "Net-4", "Net-5", "平均"]
        # The legend x-offset differs slightly between the two sizes.
        for size, legend_x in [(10, 1.38), (100, 1.40)]:
            data = data_utils.read_base_data(size)["train_metrics"]
            data_path = f"output/report/compare/one_stage/size_{size}_data.json"
            print(f"{Fore.GREEN}Reading data from {data_path}{Fore.RESET}")
            with open(data_path, "r") as f:
                one_stage_data = json.load(f)

            plt.figure(figsize=(15, 6))
            for idx, name in [(0, "Precision"), (1, "Recall"), (2, "F1_score")]:
                plt.subplot(1, 3, idx + 1)
                cur_data = [data[f"Net-{i+1}"][name] for i in range(5)]
                cur_data.append(np.mean(cur_data))
                plt.bar(x, cur_data, width, label="HO-DBN-FABC-3", color=green_color)
                cur_data = [one_stage_data[f"Net-{i+1}"][name] for i in range(5)]
                cur_data.append(np.mean(cur_data))
                plt.bar(
                    x + width, cur_data, width, label="HO-DBN-FABC-1", color=blue_color
                )
                plt.xticks(x + width, x_labels)
                plt.ylabel(name)
                plt.title(name)
            plt.legend(loc="upper right", bbox_to_anchor=(legend_x, 1))
            plt.tight_layout()
            plot_path = f"output/report/compare/one_stage/Size{size}_three_compare.png"
            print(f"{Fore.BLUE}Saving plot into {plot_path}{Fore.RESET}")
            BarPlot._save_figure(plot_path)
            plt.close()

    @staticmethod
    def DTW_steps_compare():
        """Plot iterations-to-convergence with and without DTW initialization.

        Reads ``output/report/compare/DTW/steps_compare.json``. For each size
        (10, 100), draws one pair of bars per net plus a final pair for the
        means.
        """
        info_path = "output/report/compare/DTW/steps_compare.json"
        print(f"{Fore.GREEN}reading data from {info_path}{Fore.RESET}")
        with open(info_path, "r") as f:
            data = json.load(f)
        right_move_rate = 1
        green_color, blue_color = (
            "#27ae60",  # green
            "#2980b9",  # blue
        )
        for size in [10, 100]:
            normal_list = data["step_lists"][f"size_{size}"]
            faster_list = data["step_lists"][f"size_{size}_faster"]

            # Append each series' mean as the final "average" bar.
            normal_list.append(np.mean(normal_list))
            faster_list.append(np.mean(faster_list))

            # One bar per net (plus the average) with and without DTW.
            plt.figure(figsize=(8, 5))
            x = np.arange(1, 7)
            width = 0.25

            plt.bar(
                x,
                normal_list,
                label="不使用DTW初始化",
                width=width * right_move_rate,
                color=blue_color,
            )
            plt.bar(
                x + width,
                faster_list,
                label="使用DTW初始化",
                width=width * right_move_rate,
                color=green_color,
            )
            # Value labels on top of the bars.
            for i, score in enumerate(normal_list):
                plt.text(i + 1, score, str(round(score, 2)), ha="center", va="bottom")
            for i, score in enumerate(faster_list):
                plt.text(
                    i + 1 + width * right_move_rate,
                    score,
                    str(round(score, 2)),
                    ha="center",
                    va="bottom",
                )
            plt.xticks(x, [f"Net-{i}" for i in range(1, 6)] + ["平均值"])
            plt.ylabel("迭代次数")
            plt.title(f"收敛所需的迭代次数比较（Size-{size}）")
            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
            plt.tight_layout()

            save_path = f"output/report/compare/DTW/size_{size}_steps_compare_new.png"
            print(f"{Fore.BLUE}save to {save_path}{Fore.RESET}")
            BarPlot._save_figure(save_path)
            plt.close()


class DataBuild:
    """Export a network's edge list as node/link JSON for external graph plotting.

    Each public ``build_for_*`` method reads one network from both the
    size-100 and the size-10 dataset (via ``data_utils.read_base_data``) and
    writes ``output/rePlot/<layout>_<size>.json``. Each file holds a dict with
    ``nodes``, ``links`` and ``categories`` lists. NOTE(review): the field
    names resemble an ECharts graph series; confirm against the consumer.

    The methods are static, so they work both as ``DataBuild.build_for_x()``
    and when called on an instance.
    """

    @staticmethod
    def build_for_circular(net_idx=1):
        """Write ``circular_100.json`` and ``circular_10.json``.

        Nodes are spaced evenly on a circle of radius 40 centred at (50, 50).
        Size 100 uses symbolSize 40; size 10 uses symbolSize 50.

        Args:
            net_idx: Which ``Net-<idx>`` entry of ``train_data`` to export.
                Defaults to 1.
        """
        # (size, symbolSize) pairs, processed in this order.
        for size, symbol_size in ((100, 40), (10, 50)):
            edges = DataBuild._load_edges(size, net_idx)
            positions = [
                (
                    50 + 40 * np.cos(2 * np.pi * i / size),
                    50 + 40 * np.sin(2 * np.pi * i / size),
                )
                for i in range(size)
            ]
            DataBuild._save_json(
                DataBuild._graph_info(size, positions, symbol_size, edges),
                f"output/rePlot/circular_{size}.json",
            )

    @staticmethod
    def build_for_matrix(net_idx=1):
        """Write ``matrix_100.json`` and ``matrix_10.json`` with grid layouts.

        Size 100 uses a 10x10 grid (x = column, y = row). Size 10 puts G1-G9
        on a 3x3 grid and G10 alone at (1, 3), the middle of a fourth row.
        All nodes use symbolSize 50.

        Args:
            net_idx: Which ``Net-<idx>`` entry of ``train_data`` to export.
                Defaults to 1.
        """
        for size in (100, 10):
            edges = DataBuild._load_edges(size, net_idx)
            if size == 100:
                positions = [(i % 10, i // 10) for i in range(size)]
            else:
                positions = [(i % 3, i // 3) for i in range(9)] + [(1, 3)]
            DataBuild._save_json(
                DataBuild._graph_info(size, positions, 50, edges),
                f"output/rePlot/matrix_{size}.json",
            )

    @staticmethod
    def _load_edges(size, net_idx):
        """Return the ``edges_code`` list of ``Net-<net_idx>`` for dataset ``size``."""
        data = data_utils.read_base_data(size)
        return data["train_data"][f"Net-{net_idx}"]["edges_code"]

    @staticmethod
    def _graph_info(size, positions, symbol_size, edges):
        """Assemble the ``nodes``/``links``/``categories`` dict for one network.

        ``size`` is expected to be a positive multiple of 10. Nodes are split
        into ten consecutive categories of ``size // 10`` nodes each.
        ``positions`` holds one (x, y) pair per node, in node order. Node ids
        are the 1-based strings "1".."size". Link endpoints are the string
        form of ``edge[0]`` and ``edge[1]``. NOTE(review): this assumes
        ``edges_code`` uses the same 1-based numbering; confirm.
        """
        per_group = size // 10
        if per_group == 1:
            # One category per node: G1 .. G10.
            categories = [{"name": f"G{g + 1}"} for g in range(10)]
        else:
            # Range labels, e.g. "G1-G10", "G11-G20", ...
            categories = [
                {"name": f"G{g * per_group + 1}-G{(g + 1) * per_group}"}
                for g in range(10)
            ]
        nodes = [
            {
                "id": f"{i + 1}",
                "name": f"G{i + 1}",
                "symbolSize": symbol_size,
                "category": i // per_group,
                "x": x,
                "y": y,
                "value": 1,
            }
            for i, (x, y) in enumerate(positions)
        ]
        links = [
            {"source": f"{edge[0]}", "target": f"{edge[1]}"} for edge in edges
        ]
        return {"nodes": nodes, "links": links, "categories": categories}

    @staticmethod
    def _save_json(save_info, save_path):
        """Dump ``save_info`` as JSON to ``save_path``, creating its directory."""
        print(f"{Fore.BLUE}saving data into {save_path}{Fore.RESET}")
        # Create output/rePlot/ on first run so open() does not raise
        # FileNotFoundError (same approach as DataUtils.save_cost_data).
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, "w") as f:
            json.dump(save_info, f)


if __name__ == "__main__":
    # Script entry point: each commented-out call below is a separate
    # reporting step; uncomment the one(s) to run. Currently active:
    # BarPlot.gen_one_stage_data() followed by BarPlot.compare_with_one_stage().
    # DTW_plot()
    # DTW_data_analysis()
    # fake_DTW_data_analysis()
    # loss_plot()

    # idx, size = 1, 100
    # LossPlot.single_net_plot(
    #     idx,
    #     size,
    #     scale_factor=1,
    #     add_factor=0,
    #     show=True,
    #     # save_path=f"output/report/ours/Size{size}/Net-{idx}/loss.png",
    #     task=STATIC.FASTER_TASK,
    # )
    # Plot the loss curve of each network
    # for idx in range(1, 6):
    #     for size in [10, 100]:
    #         LossPlot.single_net_plot(
    #             idx,
    #             size,
    #             scale_factor=1,
    #             add_factor=0,
    #             show=False,
    #             save_path=f"output/report/ours/Size{size}/Net-{idx}/loss.png",
    #             task=STATIC.FASTER_TASK,
    #         )
    # Simulate generating the results of the other networks
    # idx, size = 5, 10
    # for idx in range(1, 6):
    #     DataGen.generate_Net_loss(idx, size)
    #     LossPlot.single_net_plot(
    #         idx,
    #         size,
    #         scale_factor=1,
    #         add_factor=0,
    #         show=True,
    #         # save_path="output/report/ours/Size10/Net-2/loss.png",
    #     )

    # Simulate generating the accelerated ("faster") iteration process
    # DataGen.generate_faster_loss(1, 10)

    # size = 100
    # for idx in range(1, 6):
    #     for size in [10, 100]:
    #         DataGen.generate_faster_loss(idx, size)
    # break
    # break
    # Plot the speed-up obtained from DTW initialisation
    # size = 100
    # LossPlot.faster_plot_all(
    #     size,
    #     show=False,
    #     save_path=f"output/report/compare/DTW/size_{size}_all.png",
    # )

    # Draw the directed graphs of the results
    # GraphPlot.Result_Net_display()
    # Show the directed graphs across multiple steps
    # DAG_plot_step_display()

    # Plot the dyn vs. ABC comparison (F1 score)
    # BarPlot.compare_with_dyn()

    # Plot the BIC vs. ABC comparison
    # BarPlot.compare_with_BIC()

    # NOTE(review): presumably produces the one-stage results consumed by
    # compare_with_one_stage below; confirm.
    BarPlot.gen_one_stage_data()
    # Plot the first-order vs. third-order (one-stage vs. three-stage) comparison
    BarPlot.compare_with_one_stage()

    # Plot the DTW step-count comparison
    # BarPlot.DTW_steps_compare()

    # Build the JSON data used for graph plotting
    # DataBuild.build_for_circular()
    # DataBuild.build_for_matrix()
    print("DONE!@")
    # Person_plot()
